Petals to the Metal
Getting Started with TPUs on Kaggle!
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re
import warnings, math, sys, json, pprint, pdb
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
warnings.simplefilter('ignore')
print(f"Using TensorFlow v{tf.__version__}")
#@title Accelerator type { run: "auto" }
DEVICE = 'TPU' #@param ["None", "'GPU'", "'TPU'"] {type:"raw", allow-input: true}

# Pick a tf.distribute strategy for the requested accelerator.
# Any failure on the TPU path downgrades DEVICE to "GPU" so that the default
# strategy below is always created and `strategy` is never left unbound.
if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

    if tpu:
        try:
            print("initializing TPU ...")
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
            print("TPU initialized")
        except Exception as err:
            # The original `except _:` referenced an undefined name, so any
            # failure here surfaced as a NameError instead of being handled.
            print("failed to initialize TPU")
            print(f"reason: {err!r}")
            DEVICE = "GPU"
    else:
        DEVICE = "GPU"

if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

AUTOTUNE = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')
def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set the related env vars.

    Args:
        seed: Integer seed handed to every generator (default 0).
    """
    # Same three seeding calls, in the same order, expressed as a loop.
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })
def is_colab():
    """Return True when the active IPython shell belongs to Google Colab.

    `get_ipython` is only injected by IPython; in a plain Python process the
    name does not exist, which is treated here as "not Colab" instead of
    letting a NameError escape.
    """
    try:
        shell = get_ipython()
    except NameError:
        return False
    return 'google.colab' in str(shell)
#@title ML Lifecycle { run: "auto", display-mode: "form" }
# Run-level switches and the seed used for every RNG.
SEED = 16
DEBUG = False #@param {type:"boolean"}
TRAIN = True #@param {type:"boolean"}
INFERENCE = True #@param {type:"boolean"}

IS_COLAB = is_colab()
seed_everything(SEED)

# On Colab the project lives on the mounted Google Drive; otherwise paths are
# rooted at "/" (the Kaggle layout).
if IS_COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive', force_remount=True)
    root_path = '/content/gdrive/MyDrive/'
else:
    root_path = '/'

project_name = 'tpu-getting-started'
input_path = f'{root_path}kaggle/input/{project_name}/'
if IS_COLAB:
    working_path = f'{input_path}working/'
else:
    working_path = '/kaggle/working/'

# Make the working directory the current directory, then list the inputs.
os.makedirs(working_path, exist_ok=True)
os.chdir(working_path)
os.listdir(input_path)
# Glob for the source JPEGs and the output prefix used for the TFRecord shards.
GCS_PATTERN = 'gs://flowers-public/*/*.jpg'
GCS_OUTPUT = 'gs://flowers-public/tfrecords-jpeg-192x192-2/flowers'
# Number of shards; used further down to derive the per-shard batch size.
SHARDS = 16
# Size the images are resized/cropped to, and the shape read_tfrecord reshapes to.
TARGET_SIZE = [192, 192]
# Class names as byte strings; decode_image_and_label compares them against the
# second-to-last path component of each file name.
CLASSES = [b'daisy', b'dandelion', b'roses', b'sunflowers', b'tulips']
def decode_image_and_label(filename):
    """Read one JPEG and derive its label from the file path.

    Args:
        filename: Scalar string tensor holding the path of a JPEG file.

    Returns:
        (image, label): the decoded image and an int8 vector produced by
        comparing CLASSES element-wise with the second-to-last '/'-separated
        path component.
    """
    image = tf.image.decode_jpeg(tf.io.read_file(filename))
    # The scalar path is expanded to rank 1 before splitting; `.values` then
    # holds the flat list of path components.
    path_parts = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
    class_name = path_parts.values[-2]
    matches = tf.cast((CLASSES==class_name), tf.int8)
    return image, matches
# Dataset of file names matching GCS_PATTERN (seed=16 is passed to list_files).
filenames = tf.data.Dataset.list_files(GCS_PATTERN, seed=16)
# Print the first ten entries as a sanity check.
for x in filenames.take(10): print(x)
# Decode every file into an (image, label) pair, in parallel.
ds0 = filenames.map(decode_image_and_label, num_parallel_calls=AUTOTUNE)
def show_images(ds):
    """Plot the first nine (image, label) pairs of `ds` on a 3x3 grid.

    Each panel is titled with `np.argmax` of its label and drawn without axes.
    """
    _, grid = plt.subplots(3, 3, figsize=(16, 16))
    panels = grid.flatten()
    for (picture, target), panel in zip(ds.take(9), panels):
        # imshow is fed uint8 pixel data.
        pixels = picture.numpy().astype(np.uint8)
        panel.imshow(pixels)
        panel.set_title(np.argmax(target))
        panel.axis('off')
# Preview nine decoded source images, titled with the argmax of their labels.
show_images(ds0)
def resize_and_crop_image(image, label):
    """Resize and centre-crop `image` to TARGET_SIZE using the "fill" approach.

    The result is always cut out of the source image so that it fills
    TARGET_SIZE entirely, with no black bars and a preserved aspect ratio.

    Args:
        image: Image tensor; axes 0 and 1 are read as its two spatial sizes.
        label: Passed through unchanged.

    Returns:
        (image, label) with the image cropped to TARGET_SIZE.
    """
    src_w = tf.shape(image)[0]
    src_h = tf.shape(image)[1]
    dst_w, dst_h = TARGET_SIZE[1], TARGET_SIZE[0]

    # Ratio deciding which of the two scale factors is applied.
    resize_crit = (src_w * dst_h) / (src_h * dst_w)

    def _scale_with_target_w():
        return tf.image.resize(image, [src_w*dst_w/src_w, src_h*dst_w/src_w])

    def _scale_with_target_h():
        return tf.image.resize(image, [src_w*dst_h/src_h, src_h*dst_h/src_h])

    image = tf.cond(resize_crit < 1, _scale_with_target_w, _scale_with_target_h)

    # Centre the crop window inside the resized image.
    new_w = tf.shape(image)[0]
    new_h = tf.shape(image)[1]
    offset_0 = (new_w - dst_w) // 2
    offset_1 = (new_h - dst_h) // 2
    image = tf.image.crop_to_bounding_box(image, offset_0, offset_1, dst_w, dst_h)
    return image, label
# Apply the fill-style resize/crop to every (image, label) pair, then preview it.
ds1 = ds0.map(resize_and_crop_image, num_parallel_calls=AUTOTUNE)
show_images(ds1)
Speed test: too slow
Google Cloud Storage is capable of great throughput but has a per-file access penalty. Run the cell below and see that throughput is around 8 images per second. That is too slow. Training on thousands of individual files will not work. We have to use the TFRecord format to group files together.
%%time
# Pull ten batches of eight images and print each batch's shape and class indices.
for image, label in ds1.batch(8).take(10):
    print("Image batch shape {} {}".format(
        image.numpy().shape,
        list(map(np.argmax, label.numpy()))))
def recompress_image(image, label):
    """Re-encode `image` as JPEG bytes and report its first two dimensions.

    Returns:
        (jpeg_bytes, label, height, width), where height and width are the
        sizes of axes 0 and 1 of the input image.
    """
    dims = tf.shape(image)
    height, width = dims[0], dims[1]
    # encode_jpeg is fed uint8 pixels.
    as_uint8 = tf.cast(image, tf.uint8)
    encoded = tf.image.encode_jpeg(as_uint8, optimize_size=True, chroma_downsampling=False)
    return encoded, label, height, width
# NOTE: despite its name, IMAGE_SIZE is the number of files matching GCS_PATTERN.
IMAGE_SIZE = len(tf.io.gfile.glob(GCS_PATTERN))
# Images per shard, rounded up so that SHARDS batches cover every image.
SHARD_SIZE = math.ceil(1.0 * IMAGE_SIZE / SHARDS)
# Re-encode each image as JPEG bytes, keeping label, height and width alongside.
ds2 = ds1.map(recompress_image, num_parallel_calls=AUTOTUNE)
ds2 = ds2.batch(SHARD_SIZE) # sharding: there will be one "batch" of images per file
Why TFRecords?
TPUs have eight cores which act as eight independent workers. We can get data to each core more efficiently by splitting the dataset into multiple files or shards. This way, each core can grab an independent part of the data as it needs.
The most convenient kind of file to use for sharding in TensorFlow is a TFRecord. A TFRecord is a binary file that contains sequences of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord.
The most convenient way of serializing data in TensorFlow is to wrap the data with tf.Example. This is a record format based on Google's protobufs but designed for TensorFlow. It's more or less like a dict with some type annotations.
# Round-trip demo: tensor -> byte-string -> tensor.
x = tf.constant([[1,2], [3, 4]], dtype=tf.uint8)
print(x)
# Serialize the tensor into a single byte-string.
x_in_bytes = tf.io.serialize_tensor(x)
print(x_in_bytes)
# Parsing needs the dtype the tensor was created with (tf.uint8 here).
print(tf.io.parse_tensor(x_in_bytes, out_type=tf.uint8))
A TFRecord is a sequence of bytes, so we have to turn our data into byte-strings before it can go into a TFRecord. We can use tf.io.serialize_tensor to turn a tensor into a byte-string and tf.io.parse_tensor to turn it back. It's important to keep track of your tensor's datatype (in this case tf.uint8), since you have to specify it when parsing the string back into a tensor.
gs:// domain to write to.
#def _bytestring_feature(list_of_bytestrings): # bytes
# return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))
#
#def _int_feature(list_of_ints): # int64
# return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))
#
#def _float_feature(list_of_floats): # float32
# return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))
#
#def to_tfrecord(tfrec_filewriter, img_bytes, label, height, width):
# id = np.argmax(np.array(CLASSES)==label)
# one_hot = np.eye(len(CLASSES))[id]
# feature = {
# "image": _bytestring_feature([img_bytes]), # one image in the list
# "id": _int_feature([id]), # one class in the list
# "label": _bytestring_feature([label]), # fixed length (1) list of strings, the text label
# "size" : _int_feature([height, width]), # fixed length (2) list of ints
# "one_hot": _float_feature(one_hot.tolist())# variable length list of floats, n=len(CLASSES)
# }
# return tf.train.Example(features=tf.train.Features(feature=feature))
#for shard_id, (image, label, height, width) in ds2.enumerate():
# shard_size = image.numpy().shape[0]
# filename = GCS_OUTPUT + "{:02d}-{}tfrec".format(shard_id, shard_size)
#
# with tf.io.TFRecordWriter(filename) as out_file:
# for i in range(shard_size):
# example = to_tfrecord(out_file,
# image.numpy()[i],
# label.numpy()[i],
# height.numpy()[i],
# width.numpy()[i])
# out_file.write(example.SerializeToString())
# print("Wrote file {} containing {} records".format(filename, shard_size))
def read_tfrecord(example):
    """Parse one serialized example from the flower TFRecord files.

    Args:
        example: Scalar string tensor holding a serialized tf.train.Example.

    Returns:
        (image, class_num, label, height, width, one_hot_class), where the
        image is decoded with 3 channels and reshaped to [*TARGET_SIZE, 3].
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string = bytestring (not text string)
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        # additional (not very useful) fields to demonstrate TFRecord
        # writing/reading of different types of data
        "label": tf.io.FixedLenFeature([], tf.string),  # one bytestring
        "size": tf.io.FixedLenFeature([2], tf.int64),  # two integers
        "one_hot_class": tf.io.VarLenFeature(tf.float32),  # a certain number of floats
    }
    parsed = tf.io.parse_single_example(example, schema)

    # FixedLenFeature fields are ready to use as parsed.
    decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    image = tf.reshape(decoded, [*TARGET_SIZE, 3])
    height, width = parsed['size'][0], parsed['size'][1]

    # VarLenFeature fields come back sparse and need densifying.
    one_hot_class = tf.sparse.to_dense(parsed['one_hot_class'])
    return image, parsed['class'], parsed['label'], height, width, one_hot_class
# Options object that switches off deterministic element ordering.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False
# All TFRecord shards written under the GCS_OUTPUT prefix.
filenames = tf.io.gfile.glob(GCS_OUTPUT + "*tfrec")
ds3 = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
# Parse the records in parallel and shuffle with a 30-element buffer.
ds3 = (ds3.with_options(option_no_order)
.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
.shuffle(30))
# Keep only (image, label) so the dataset matches what show_images iterates over.
ds3_to_show = ds3.map(lambda image, id, label, height, width, one_hot: (image, label))
show_images(ds3_to_show)
%%time
# Pull ten batches of eight parsed records and print shape plus text labels.
for image, class_num, label, height, width, one_hot_class in ds3.batch(8).take(10):
    print("Image batch shape {} {}".format(
        image.numpy().shape,
        [raw_label.decode('utf8') for raw_label in label.numpy()]))
# Training hyper-parameters (Colab form fields).
BASE_MODEL = 'efficientnet_b3' #@param ["'efficientnet_b3'", "'efficientnet_b4'", "'efficientnet_b2'"] {type:"raw", allow-input: true}
HEIGHT = 300#@param {type:"number"}
WIDTH = 300#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
# IMG_SIZE already includes the channel dimension.
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)
EPOCHS = 12 #@param {type:"number"}
# Global batch size: 16 examples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync #@param {type:"raw"}
print("Using {} with input size {}".format(BASE_MODEL, IMG_SIZE))
# NOTE(review): the path below selects the 512x512 TFRecords while IMG_SIZE is
# 300x300 - confirm the decoding step accounts for the difference.
GCS_PATH = 'gs://kds-e93303da9a97ef8fd254ceb5e9ed104470f247527dd45aba9685bdf5' #@param {type: "string"}
GCS_PATH += '/tfrecords-jpeg-512x512' #@param {type: "string"}
# Class names indexed by class id. This rebinds CLASSES, which earlier in the
# file held five byte-string flower names.
CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', # 00 - 09
'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', # 10 - 19
'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', # 20 - 29
'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', # 30 - 39
'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion', # 40 - 49
'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia', # 50 - 59
'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy', # 60 - 69
'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', # 70 - 79
'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily', # 80 - 89
'hippeastrum ', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', # 90 - 99
'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose']
print(f"Sourcing images from {GCS_PATH}")
def decode_image(image_data):
    """Decode JPEG bytes into a float image of shape IMG_SIZE.

    The TFRecords are read from a `tfrecords-jpeg-512x512` directory while
    IMG_SIZE is (HEIGHT, WIDTH, CHANNELS) = (300, 300, 3). A bare reshape
    cannot convert between those sizes because the element counts differ, so
    the image is resized to the configured resolution first.

    Args:
        image_data: Scalar string tensor holding JPEG-encoded bytes.

    Returns:
        float32 tensor of static shape IMG_SIZE with values scaled by 1/255.
    """
    image = tf.image.decode_jpeg(image_data, channels=CHANNELS)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    # Bring any source resolution to the configured one.
    image = tf.image.resize(image, [HEIGHT, WIDTH])
    image = tf.reshape(image, IMG_SIZE)  # explicit size needed for TPU
    return image
def collate_labeled_tfrecord(example):
    """Parse a labeled record into an (image, label) pair.

    Args:
        example: Scalar string tensor holding a serialized tf.train.Example
            with an "image" bytestring and an int64 "class".

    Returns:
        (image, label): the image as produced by decode_image and the class
        cast to int32.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)
def process_unlabeled_tfrecord(example):
    """Parse an unlabeled record into an (image, id) pair.

    Args:
        example: Scalar string tensor holding a serialized tf.train.Example
            with an "image" bytestring and a string "id".

    Returns:
        (image, idnum): the image as produced by decode_image and the record's
        id string, unmodified.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), parsed['id']
def count_data_items(filenames):
    """Sum the record counts embedded in TFRecord file names.

    Each name is expected to contain "-<digits>." (for example
    "00-512x512-798.tfrec" encodes 798 records).

    Args:
        filenames: Iterable of file name strings.

    Returns:
        The total as returned by `np.sum` over the per-file counts.

    Raises:
        ValueError: If a file name carries no "-<digits>." count.
    """
    # Compile once instead of once per file name.
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None or not match.group(1):
            raise ValueError(f"no record count found in file name: {filename!r}")
        counts.append(int(match.group(1)))
    return np.sum(counts)
# TFRecord shards for each split, globbed from sub-directories of GCS_PATH.
train_filenames = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
valid_filenames = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
test_filenames = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec')
# data augmentation @cdeotte kernel:
# https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def _remap_pixels(image, dim, matrix):
    """Resample a [dim, dim, 3] image through a 3x3 transform matrix.

    Shared by transform_rotation, transform_shear and transform_shift, which
    previously each carried an identical copy of this code.

    Args:
        image: One image of size [dim, dim, 3] (not a batch).
        dim: Side length used to build the pixel grid.
        matrix: float32 tensor of shape [3, 3] applied to the homogeneous
            destination pixel coordinates.

    Returns:
        Tensor of shape [dim, dim, 3] gathered from `image`.
    """
    xdim = dim % 2  # fix for size 331
    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat(tf.range(dim // 2, -dim // 2, -1), dim)
    y = tf.tile(tf.range(-dim // 2, dim // 2), [dim])
    z = tf.ones([dim * dim], dtype='int32')
    idx = tf.stack([x, y, z])
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(matrix, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    # Clip so that every computed source index stays inside the image.
    idx2 = K.clip(idx2, -dim // 2 + xdim + 1, dim // 2)
    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack([dim // 2 - idx2[0, ], dim // 2 - 1 + idx2[1, ]])
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d, [dim, dim, 3])


def transform_rotation(image, height, rotation):
    """Rotate one image by a random angle.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Side length `dim` of the image.
        rotation: Upper bound in degrees; the applied angle is `rotation`
            times a uniform random sample.

    Returns:
        The randomly rotated image, shape [height, height, 3].
    """
    rotation = rotation * tf.random.uniform([1], dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = tf.reshape(
        tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=0), [3, 3])
    return _remap_pixels(image, height, rotation_matrix)


def transform_shear(image, height, shear):
    """Shear one image by a random angle.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Side length `dim` of the image.
        shear: Upper bound in degrees; the applied angle is `shear` times a
            uniform random sample.

    Returns:
        The randomly sheared image, shape [height, height, 3].
    """
    shear = shear * tf.random.uniform([1], dtype='float32')
    shear = math.pi * shear / 180.
    # SHEAR MATRIX
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(
        tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=0), [3, 3])
    return _remap_pixels(image, height, shear_matrix)


def transform_shift(image, height, h_shift, w_shift):
    """Shift one image by random offsets along both axes.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Side length `dim` of the image.
        h_shift: Upper bound of the offset placed in the first matrix row.
        w_shift: Upper bound of the offset placed in the second matrix row.

    Returns:
        The randomly shifted image, shape [height, height, 3].
    """
    height_shift = h_shift * tf.random.uniform([1], dtype='float32')
    width_shift = w_shift * tf.random.uniform([1], dtype='float32')
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    # SHIFT MATRIX
    shift_matrix = tf.reshape(
        tf.concat([one, zero, height_shift, zero, one, width_shift, zero, zero, one], axis=0),
        [3, 3])
    return _remap_pixels(image, height, shift_matrix)
def data_augment(image, label):
    """Apply a random chain of augmentations to one training image.

    Each augmentation family (flip, 90-degree rotation, affine rotation, shift,
    shear, crop, pixel-level change) is gated by its own uniform sample in
    [0, 1). The image is always resized back to [HEIGHT, WIDTH] after the crop
    stage.

    Args:
        image: One image tensor (not a batch).
        label: Passed through unchanged.

    Returns:
        (image, label) with the augmented image.
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    if p_spatial >= .2:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)  # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)  # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)  # rotate 90º

    if p_rotation >= .3:  # Rotation
        image = transform_rotation(image, height=HEIGHT, rotation=45.)
    if p_shift >= .3:  # Shift
        image = transform_shift(image, height=HEIGHT, h_shift=15., w_shift=15.)
    if p_shear >= .3:  # Shear
        image = transform_shear(image, height=HEIGHT, shear=20.)

    # Crops
    # The stricter threshold is tested first: with `> .4` checked before
    # `> .7`, the central-crop branch could never be reached.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.7), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])

    # Restore the configured resolution whatever crop (if any) was applied.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    # Pixel-level transforms
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=0, upper=2)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.2)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)

    return image, label
# Options object that switches off deterministic element ordering; it is applied
# to the validation and test pipelines below, not to the training pipeline.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False
# Training pipeline: parse -> augment -> repeat -> shuffle -> batch -> prefetch.
train_ds = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
train_ds = (train_ds
.map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
.map(data_augment, num_parallel_calls=AUTOTUNE)
.repeat()
.shuffle(2048)
.batch(BATCH_SIZE)
.prefetch(AUTOTUNE))
# Validation pipeline: parse -> batch -> cache -> prefetch (no augmentation).
valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
valid_ds = (valid_ds
.with_options(option_no_order)
.map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
.batch(BATCH_SIZE)
.cache()
.prefetch(AUTOTUNE))
# Test pipeline: yields (image, id) pairs parsed from the unlabeled records.
test_ds = tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
test_ds = (test_ds
.with_options(option_no_order)
.map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(AUTOTUNE))
# Keras preprocessing-layer version of the augmentation pipeline.
data_augmentation = tf.keras.Sequential(
    [
        # RandomCrop takes (height, width); unpacking IMG_SIZE also passed
        # CHANNELS as the third positional argument, which is `seed`.
        tf.keras.layers.experimental.preprocessing.RandomCrop(HEIGHT, WIDTH),
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
        tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
        tf.keras.layers.experimental.preprocessing.RandomContrast((0.2, 0.2)),
    ]
)


def func(x, y):
    """Run `data_augmentation` on the images and pass the labels through."""
    return data_augmentation(x), y


# Apply the augmentation layers to a single training batch.
x = (train_ds
     .take(1)
     .map(func, num_parallel_calls=AUTOTUNE))
Building a model
Now we're ready to create a neural network for classifying images! We'll use what's known as transfer learning. With transfer learning, you reuse part of a pretrained model to get a headstart on a new dataset.
For this tutorial, we'll use a model called VGG16 pretrained on ImageNet. Later, you might want to experiment with other models included with Keras. (Xception wouldn't be a bad choice.)
The distribution strategy we created earlier contains a context manager, strategy.scope. This context manager tells TensorFlow how to divide the work of training among the eight TPU cores. When using TensorFlow with a TPU, it's important to define your model in a strategy.scope() context.
EPOCHS = 12

# Build the model inside the distribution strategy's scope.
with strategy.scope():
    pretrained_model = tf.keras.applications.VGG16(
        weights='imagenet',
        include_top=False,
        # IMG_SIZE is already (HEIGHT, WIDTH, CHANNELS); appending another 3
        # produced a four-element input shape.
        input_shape=IMG_SIZE,
    )
    # Freeze the pretrained base so only the new head is trained.
    pretrained_model.trainable = False

    model = tf.keras.Sequential([
        # To a base pretrained on ImageNet to extract features from images...
        pretrained_model,
        # ... attach a new head to act as a classifier.
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax'),
    ])
I use the CosineDecayRestarts schedule implemented in tf.keras, as it seemed promising and I struggled to find the right settings (if there were any) for ReduceLROnPlateau.
EPOCHS = 12
# Count the training items once and derive both step counts from it, so that
# the plotted horizon equals the number of optimizer steps fit() will run
# (steps_per_epoch * epochs) rather than a separately rounded estimate.
STEPS_PER_EPOCH = count_data_items(train_filenames) // BATCH_SIZE
STEPS = int(STEPS_PER_EPOCH) * EPOCHS

# Cosine-decay-with-restarts learning-rate schedule.
schedule = tf.keras.experimental.CosineDecayRestarts(
    initial_learning_rate=1e-4,
    first_decay_steps=180
)
schedule.get_config()

# Plot the learning rate at every training step.
x = list(range(STEPS))
y = [schedule(s) for s in x]
plt.plot(x, y)
# Save the model to '001_best_model.h5' whenever val_loss improves.
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath='001_best_model.h5',
monitor='val_loss',
save_best_only=True),
]
# Adam driven by the learning-rate schedule; labels are integer class ids,
# hence the sparse loss and metric.
model.compile(
optimizer=tf.keras.optimizers.Adam(schedule),
loss = 'sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'],
)
model.summary()
# train_ds repeats indefinitely, so steps_per_epoch defines the epoch length.
history = model.fit(
x=train_ds,
validation_data=valid_ds,
epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
callbacks=callbacks
)
def plot_hist(hist):
    """Plot training and validation loss per epoch.

    Args:
        hist: Object exposing a `history` mapping with 'loss' and 'val_loss'
            sequences (such as the value returned by `model.fit`).
    """
    # Read from the argument; the previous body ignored `hist` and always
    # plotted the module-level `history`.
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('Loss over epochs')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'valid'], loc='best')
    plt.show()
def display_training_curves(training, validation, title, subplot):
    """Draw a training curve and a validation curve on one subplot.

    Args:
        training: Per-epoch values for the training set.
        validation: Per-epoch values for the validation set.
        title: Metric name, used for the subplot title and the y-axis label.
        subplot: Matplotlib subplot code; a code whose last digit is 1 creates
            the figure first.
    """
    starts_new_figure = subplot % 10 == 1
    if starts_new_figure:  # set up the subplots on the first call
        plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
        plt.tight_layout()

    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    # ax.set_ylim(0.28, 1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
def display_confusion_matrix(cmat, score, precision, recall):
    """Show a confusion matrix with class-name ticks and a metrics caption.

    Args:
        cmat: Matrix to display.
        score: F1 score, or None to leave it out of the caption.
        precision: Precision, or None to leave it out of the caption.
        recall: Recall, or None to leave it out of the caption.
    """
    plt.figure(figsize=(15, 15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    # One tick per class on both axes, labelled with the class names.
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Build the caption from whichever metrics were supplied.
    caption_parts = []
    if score is not None:
        caption_parts.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        caption_parts.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        caption_parts.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(caption_parts)

    if titlestring:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
def display_training_curves(training, validation, title, subplot):
    """Plot training vs. validation values for one metric.

    This is a second definition of `display_training_curves`; it rebinds the
    name defined earlier in the file with the same behavior.

    Args:
        training: Per-epoch values for the training set.
        validation: Per-epoch values for the validation set.
        title: Metric name, used for the subplot title and the y-axis label.
        subplot: Matplotlib subplot code; a code whose last digit is 1 creates
            the figure first.
    """
    if subplot % 10 == 1:  # set up the subplots on the first call
        plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
        plt.tight_layout()

    axes = plt.subplot(subplot)
    axes.set_facecolor('#F8F8F8')
    axes.plot(training)
    axes.plot(validation)

    heading = 'model ' + title
    axes.set_title(heading)
    axes.set_ylabel(title)
    # axes.set_ylim(0.28, 1.05)
    axes.set_xlabel('epoch')
    axes.legend(['train', 'valid.'])
# Loss curves in the upper panel; subplot code 211 ends in 1, so this call also
# creates the figure.
display_training_curves(
history.history['loss'],
history.history['val_loss'],
'loss',
211,
)
# Accuracy curves in the lower panel of the same figure.
display_training_curves(
history.history['sparse_categorical_accuracy'],
history.history['val_sparse_categorical_accuracy'],
'accuracy',
212,
)
# Separate pass over the validation files for the confusion matrix.
cmat_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
cmat_ds = (cmat_ds
.map(collate_labeled_tfrecord)
.batch(BATCH_SIZE)
.cache()
.prefetch(AUTOTUNE))
# Split the cached dataset into an image stream and a label stream.
images_ds = cmat_ds.map(lambda image, label: image)
labels_ds = cmat_ds.map(lambda image, label: label).unbatch()
# Collect every label in one batch sized by the count parsed from the file names.
cm_correct_labels = next(iter(labels_ds.batch(count_data_items(valid_filenames)))).numpy()
cm_probabilities = model.predict(images_ds)
# Predicted class = index of the highest probability.
cm_predictions = np.argmax(cm_probabilities, axis=-1)
labels = range(len(CLASSES))
cmat = confusion_matrix(
cm_correct_labels,
cm_predictions,
labels=labels,
)
cmat = (cmat.T / cmat.sum(axis=1)).T # normalize
You might be familiar with metrics like F1-score or precision and recall. This cell will compute these metrics and display them with a plot of the confusion matrix. (These metrics are defined in the Scikit-learn module sklearn.metrics; we've imported them in the helper script for you.)
# Macro-averaged F1, precision and recall over all class ids in `labels`.
score = f1_score(
cm_correct_labels,
cm_predictions,
labels=labels,
average='macro',
)
precision = precision_score(
cm_correct_labels,
cm_predictions,
labels=labels,
average='macro',
)
recall = recall_score(
cm_correct_labels,
cm_predictions,
labels=labels,
average='macro',
)
# Plot the normalized confusion matrix with the three metrics as a caption.
display_confusion_matrix(cmat, score, precision, recall)
Visual Validation
It can also be helpful to look at some examples from the validation set and see what class your model predicted. This can help reveal patterns in the kinds of images your model has trouble with. This cell will set up the validation set to display 20 images at a time -- you can change this to display more or fewer, if you like.
1% Better Everyday
reference
- Create Your First Submission
- How to use my own data source?
- TPU-speed data pipelines: tf.data.Dataset and TFRecords
todos
- Comment out the 1/255.0 in the image preprocessing
- Reorganize the notebook structure
done